import time
import torch
from util.trainer.AverageMeter import AverageMeter
from util.trainer.accuracy import accuracy
from util.trainer.loss_and_top1_acc import loss_and_top1_acc
def validate(val_loader, model, criterion, local_rank, rank):
    """Evaluate ``model`` on ``val_loader`` and report loss and top-1 precision.

    Runs one full pass over ``val_loader`` under ``torch.no_grad()``. During the
    pass it accumulates a running loss and top-1 precision for this process. It
    then combines them via ``loss_and_top1_acc``. That helper presumably
    aggregates the meters across GPUs/processes; confirm against its
    implementation.

    Note: the model is switched to ``eval()`` mode and is NOT switched back.
    Callers that continue training must call ``model.train()`` themselves.

    Args:
        val_loader: Iterable yielding ``(inputs, targets)`` tensor batches.
        model: Module to evaluate; it is put into ``eval()`` mode.
        criterion: Loss function called as ``criterion(output, targets)``.
            It must return a single-element tensor, since ``.item()`` is
            called on it.
        local_rank: CUDA device index that each batch is moved to.
        rank: Global process rank. Only rank 0, or ``None`` for
            non-distributed runs, prints the summary line.

    Returns:
        The top-1 precision value returned by ``loss_and_top1_acc``.
    """
    loss_meter = AverageMeter()
    top1_meter = AverageMeter()

    model.eval()

    with torch.no_grad():
        for inputs, targets in val_loader:
            inputs = inputs.cuda(local_rank)
            targets = targets.cuda(local_rank)

            output = model(inputs)
            loss = criterion(output, targets)

            # Upcast to fp32 before computing accuracy; this is a no-op when the
            # model already produces fp32 outputs.
            output = output.float()

            # accuracy(...) returns an indexable result; element 0 is used as
            # the top-1 precision.
            prec1 = accuracy(output, targets)[0]

            # The batch size is passed as the meter count, presumably so that
            # AverageMeter weights each batch's value by its size.
            batch_size = inputs.size(0)
            loss_meter.update(loss.item(), batch_size)
            top1_meter.update(prec1.item(), batch_size)

    loss_avg, top1_acc = loss_and_top1_acc(loss_meter, top1_meter, local_rank)

    if rank is None or rank == 0:
        print('valid:     \t'
              'Loss {loss:.4f}\t'
              'Prec@1 {top1_acc:.3f}'.format(loss=loss_avg, top1_acc=top1_acc))

    return top1_acc